{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {},
   "source": [
    "This example requires the following dependencies to be installed:\n",
    "pip install lightly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install lightly"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2",
   "metadata": {},
   "source": [
    "Note: The model and training settings do not follow the reference settings\n",
    "from the paper. The settings are chosen such that the example can easily be\n",
    "run on a small dataset with a single GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pytorch_lightning as pl\n",
    "import torch\n",
    "import torchvision\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from lightly.loss import NTXentLoss\n",
    "from lightly.models import utils\n",
    "from lightly.models.modules import DenseCLProjectionHead\n",
    "from lightly.transforms import DenseCLTransform\n",
    "from lightly.utils.scheduler import cosine_schedule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DenseCL(pl.LightningModule):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        resnet = torchvision.models.resnet18()\n",
    "        self.backbone = nn.Sequential(*list(resnet.children())[:-2])\n",
    "        self.projection_head_global = DenseCLProjectionHead(512, 512, 128)\n",
    "        self.projection_head_local = DenseCLProjectionHead(512, 512, 128)\n",
    "\n",
    "        self.backbone_momentum = copy.deepcopy(self.backbone)\n",
    "        self.projection_head_global_momentum = copy.deepcopy(\n",
    "            self.projection_head_global\n",
    "        )\n",
    "        self.projection_head_local_momentum = copy.deepcopy(self.projection_head_local)\n",
    "        self.pool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "\n",
    "        utils.deactivate_requires_grad(self.backbone_momentum)\n",
    "        utils.deactivate_requires_grad(self.projection_head_global_momentum)\n",
    "        utils.deactivate_requires_grad(self.projection_head_local_momentum)\n",
    "\n",
    "        self.criterion_global = NTXentLoss(memory_bank_size=(4096, 128))\n",
    "        self.criterion_local = NTXentLoss(memory_bank_size=(4096, 128))\n",
    "\n",
    "    def forward(self, x):\n",
    "        query_features = self.backbone(x)\n",
    "        query_global = self.pool(query_features).flatten(start_dim=1)\n",
    "        query_global = self.projection_head_global(query_global)\n",
    "        query_features = query_features.flatten(start_dim=2).permute(0, 2, 1)\n",
    "        query_local = self.projection_head_local(query_features)\n",
    "        # Shapes: (B, H*W, C), (B, D), (B, H*W, D)\n",
    "        return query_features, query_global, query_local\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def forward_momentum(self, x):\n",
    "        key_features = self.backbone(x)\n",
    "        key_global = self.pool(key_features).flatten(start_dim=1)\n",
    "        key_global = self.projection_head_global(key_global)\n",
    "        key_features = key_features.flatten(start_dim=2).permute(0, 2, 1)\n",
    "        key_local = self.projection_head_local(key_features)\n",
    "        return key_features, key_global, key_local\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)\n",
    "        utils.update_momentum(model.backbone, model.backbone_momentum, m=momentum)\n",
    "        utils.update_momentum(\n",
    "            model.projection_head_global,\n",
    "            model.projection_head_global_momentum,\n",
    "            m=momentum,\n",
    "        )\n",
    "        utils.update_momentum(\n",
    "            model.projection_head_local,\n",
    "            model.projection_head_local_momentum,\n",
    "            m=momentum,\n",
    "        )\n",
    "\n",
    "        x_query, x_key = batch[0]\n",
    "        query_features, query_global, query_local = self(x_query)\n",
    "        key_features, key_global, key_local = self.forward_momentum(x_key)\n",
    "\n",
    "        key_local = utils.select_most_similar(query_features, key_features, key_local)\n",
    "        query_local = query_local.flatten(end_dim=1)\n",
    "        key_local = key_local.flatten(end_dim=1)\n",
    "\n",
    "        loss_global = self.criterion_global(query_global, key_global)\n",
    "        loss_local = self.criterion_local(query_local, key_local)\n",
    "        lambda_ = 0.5\n",
    "        loss = (1 - lambda_) * loss_global + lambda_ * loss_local\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optim = torch.optim.SGD(self.parameters(), lr=0.06)\n",
    "        return optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = DenseCL()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8",
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = DenseCLTransform(input_size=32)\n",
    "dataset = torchvision.datasets.CIFAR10(\n",
    "    \"datasets/cifar10\", download=True, transform=transform\n",
    ")\n",
    "# or create a dataset from a folder containing images or videos:\n",
    "# dataset = LightlyDataset(\"path/to/folder\", transform=transform)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=256,\n",
    "    shuffle=True,\n",
    "    drop_last=True,\n",
    "    num_workers=8,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10",
   "metadata": {},
   "outputs": [],
   "source": [
    "accelerator = \"gpu\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = pl.Trainer(max_epochs=10, devices=1, accelerator=accelerator)\n",
    "trainer.fit(model=model, train_dataloaders=dataloader)"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all",
   "main_language": "python",
   "notebook_metadata_filter": "-all"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
